{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Mnist数据集逻辑回归分类任务"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "需要的工具包"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/home/shoremei/anaconda3/lib/python3.6/site-packages/h5py/__init__.py:36: FutureWarning: Conversion of the second argument of issubdtype from `float` to `np.floating` is deprecated. In future, it will be treated as `np.float64 == np.dtype(float).type`.\n",
      "  from ._conv import register_converters as _register_converters\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "WARNING:tensorflow:From <ipython-input-1-cf63bf3938a4>:4: read_data_sets (from tensorflow.contrib.learn.python.learn.datasets.mnist) is deprecated and will be removed in a future version.\n",
      "Instructions for updating:\n",
      "Please use alternatives such as official/mnist/dataset.py from tensorflow/models.\n",
      "WARNING:tensorflow:From /home/shoremei/anaconda3/lib/python3.6/site-packages/tensorflow/contrib/learn/python/learn/datasets/mnist.py:260: maybe_download (from tensorflow.contrib.learn.python.learn.datasets.base) is deprecated and will be removed in a future version.\n",
      "Instructions for updating:\n",
      "Please write your own downloading logic.\n",
      "WARNING:tensorflow:From /home/shoremei/anaconda3/lib/python3.6/site-packages/tensorflow/contrib/learn/python/learn/datasets/mnist.py:262: extract_images (from tensorflow.contrib.learn.python.learn.datasets.mnist) is deprecated and will be removed in a future version.\n",
      "Instructions for updating:\n",
      "Please use tf.data to implement this functionality.\n",
      "Extracting data/train-images-idx3-ubyte.gz\n",
      "WARNING:tensorflow:From /home/shoremei/anaconda3/lib/python3.6/site-packages/tensorflow/contrib/learn/python/learn/datasets/mnist.py:267: extract_labels (from tensorflow.contrib.learn.python.learn.datasets.mnist) is deprecated and will be removed in a future version.\n",
      "Instructions for updating:\n",
      "Please use tf.data to implement this functionality.\n",
      "Extracting data/train-labels-idx1-ubyte.gz\n",
      "WARNING:tensorflow:From /home/shoremei/anaconda3/lib/python3.6/site-packages/tensorflow/contrib/learn/python/learn/datasets/mnist.py:110: dense_to_one_hot (from tensorflow.contrib.learn.python.learn.datasets.mnist) is deprecated and will be removed in a future version.\n",
      "Instructions for updating:\n",
      "Please use tf.one_hot on tensors.\n",
      "Extracting data/t10k-images-idx3-ubyte.gz\n",
      "Extracting data/t10k-labels-idx1-ubyte.gz\n",
      "WARNING:tensorflow:From /home/shoremei/anaconda3/lib/python3.6/site-packages/tensorflow/contrib/learn/python/learn/datasets/mnist.py:290: DataSet.__init__ (from tensorflow.contrib.learn.python.learn.datasets.mnist) is deprecated and will be removed in a future version.\n",
      "Instructions for updating:\n",
      "Please use alternatives such as official/mnist/dataset.py from tensorflow/models.\n",
      "done\n"
     ]
    }
   ],
   "source": [
    "from tensorflow.examples.tutorials.mnist import input_data\n",
    "import tensorflow as tf\n",
    "\n",
    "mnist = input_data.read_data_sets('data/', one_hot=True)\n",
    "print(\"done\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "设置我们的参数"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "numClasses = 10\n",
    "inputSize = 784  \n",
    "trainingIterations = 50000\n",
    "batchSize = 64"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "指定好x和y的大小"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "X = tf.placeholder(tf.float32, shape = [None, inputSize])\n",
    "y = tf.placeholder(tf.float32, shape = [None, numClasses])"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "参数初始化"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "WARNING:tensorflow:From /home/shoremei/anaconda3/lib/python3.6/site-packages/tensorflow/python/framework/op_def_library.py:263: colocate_with (from tensorflow.python.framework.ops) is deprecated and will be removed in a future version.\n",
      "Instructions for updating:\n",
      "Colocations handled automatically by placer.\n"
     ]
    }
   ],
   "source": [
    "W1 = tf.Variable(tf.random_normal([inputSize, numClasses], stddev=0.1))\n",
    "B1 = tf.Variable(tf.constant(0.1), [numClasses])"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "构造模型"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "y_pred = tf.nn.softmax(tf.matmul(X, W1) + B1)\n",
    "\n",
    "loss = tf.reduce_mean(tf.square(y - y_pred))\n",
    "opt = tf.train.GradientDescentOptimizer(learning_rate = .05).minimize(loss)\n",
    "\n",
    "correct_prediction = tf.equal(tf.argmax(y_pred,1), tf.argmax(y,1))\n",
    "accuracy = tf.reduce_mean(tf.cast(correct_prediction, \"float\"))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "sess = tf.Session()\n",
    "init = tf.global_variables_initializer()\n",
    "sess.run(init)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "迭代计算"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "step 0, training accuracy 0.125\n",
      "step 1000, training accuracy 0.515625\n",
      "step 2000, training accuracy 0.6875\n",
      "step 3000, training accuracy 0.75\n",
      "step 4000, training accuracy 0.828125\n",
      "step 5000, training accuracy 0.8125\n",
      "step 6000, training accuracy 0.859375\n",
      "step 7000, training accuracy 0.75\n",
      "step 8000, training accuracy 0.859375\n",
      "step 9000, training accuracy 0.890625\n",
      "step 10000, training accuracy 0.859375\n",
      "step 11000, training accuracy 0.84375\n",
      "step 12000, training accuracy 0.921875\n",
      "step 13000, training accuracy 0.84375\n",
      "step 14000, training accuracy 0.90625\n",
      "step 15000, training accuracy 0.921875\n",
      "step 16000, training accuracy 0.90625\n",
      "step 17000, training accuracy 0.875\n",
      "step 18000, training accuracy 0.90625\n",
      "step 19000, training accuracy 0.84375\n",
      "step 20000, training accuracy 0.859375\n",
      "step 21000, training accuracy 0.921875\n",
      "step 22000, training accuracy 0.9375\n",
      "step 23000, training accuracy 0.8125\n",
      "step 24000, training accuracy 0.953125\n",
      "step 25000, training accuracy 0.921875\n",
      "step 26000, training accuracy 0.921875\n",
      "step 27000, training accuracy 0.875\n",
      "step 28000, training accuracy 0.875\n",
      "step 29000, training accuracy 0.9375\n",
      "step 30000, training accuracy 0.828125\n",
      "step 31000, training accuracy 0.90625\n",
      "step 32000, training accuracy 0.921875\n",
      "step 33000, training accuracy 0.828125\n",
      "step 34000, training accuracy 0.890625\n",
      "step 35000, training accuracy 0.890625\n",
      "step 36000, training accuracy 0.9375\n",
      "step 37000, training accuracy 0.859375\n",
      "step 38000, training accuracy 0.921875\n",
      "step 39000, training accuracy 0.890625\n",
      "step 40000, training accuracy 0.921875\n",
      "step 41000, training accuracy 0.859375\n",
      "step 42000, training accuracy 0.90625\n",
      "step 43000, training accuracy 0.875\n",
      "step 44000, training accuracy 0.921875\n",
      "step 45000, training accuracy 0.9375\n",
      "step 46000, training accuracy 0.953125\n",
      "step 47000, training accuracy 0.890625\n",
      "step 48000, training accuracy 0.890625\n",
      "step 49000, training accuracy 0.875\n"
     ]
    }
   ],
   "source": [
    "for i in range(trainingIterations):\n",
    "    batch = mnist.train.next_batch(batchSize)\n",
    "    batchInput = batch[0]\n",
    "    batchLabels = batch[1]\n",
    "    _, trainingLoss = sess.run([opt, loss], feed_dict={X: batchInput, y: batchLabels})\n",
    "    \n",
    "    if i%1000 == 0:\n",
    "        train_accuracy = accuracy.eval(session=sess, feed_dict={X: batchInput, y: batchLabels})\n",
    "        print (\"step %d, training accuracy %g\"%(i, train_accuracy))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "测试结果"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "test accuracy 0.9375\n"
     ]
    }
   ],
   "source": [
    "batch = mnist.test.next_batch(batchSize)\n",
    "testAccuracy = sess.run(accuracy, feed_dict={X: batch[0], y: batch[1]})\n",
    "print (\"test accuracy %g\"%(testAccuracy))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "anaconda-cloud": {},
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.7"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 1
}
